forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 73
[release/2.8] [Bugfix][Inductor] Fix dependency list merged incorrectly for a custo… #2419
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
…m op with multiple mutated inputs and None return type. (pytorch#157133) This is an attempt to fix a memory allocation issue when using `torch.compile` with a custom layernorm kernel in vllm: ```C++ // In-place fused Add and RMS Normalization. ops.def( "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); ``` We observed abnormal extra memory allocations with this op enabled using `torch.compile`: <img width="738" alt="{374E9FCF-FB46-4750-8B60-D31E3ADCE00A}" src="https://github.com/user-attachments/assets/6c45e1aa-ccde-4c56-99dc-bf4776d699d5" /> and without this op: <img width="738" alt="{9BB08EFE-FFE3-4D06-82C0-C70BBE6ADD56}" src="https://github.com/user-attachments/assets/56e2ee43-ab87-492d-834c-69e9cafbb0df" /> After investigation, we found that this is because the compiler considers the two buffers for the two mutated inputs `Tensor input` and `Tensor residual` should share a same dependency list, which makes it can not reuse the buffer of `Tensor input`. ``` buf1.users = [ NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False), ] buf16.users = [ NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False), ] ``` ``` op13: ExternKernelSchedulerNode(FallbackKernel) op13.writes = [ StarDep(name='buf17', mode=None), StarDep(name='buf18', mode=None), StarDep(name='buf19', mode=None)] op13.unmet_dependencies = [ StarDep(name='buf13', mode=None), StarDep(name='buf16', mode=None), WeakDep(name='buf11', mutating_buf='buf18'), WeakDep(name='buf12', mutating_buf='buf18'), WeakDep(name='buf13', mutating_buf='buf18'), WeakDep(name='buf2', mutating_buf='buf18'), WeakDep(name='buf3', mutating_buf='buf18')] op13.met_dependencies = [StarDep(name='arg11_1', mode=None)] op13.outputs = [ buf17: FallbackKernel buf17.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0]) buf17.aliases = ['buf16', 'buf1'] buf17.users = [ NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False), ] buf18: MutationOutput buf18.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0]) buf18.mutations = ['buf16'] buf18.users = [ NodeUser(node=ExternKernelSchedulerNode(name='op14'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=True), ] buf19: MutationOutput buf19.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0]) buf19.mutations = ['buf1'] buf19.users = [NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False)] ] op13.node.kernel = torch.ops._C.fused_add_rms_norm.default ``` Here we can see `buf16` shares the same dependency list with `buf1` because `buf16` and `buf1` are in the aliases list of `buf17`. This is incorrect since those two are two separate tensors. And this makes the compiler could not reuse `buf16` for subsequent ops. Pull Request resolved: pytorch#157133 Approved by: https://github.com/jansel (cherry picked from commit 02724b5)
pragupta
added a commit
that referenced
this pull request
Jul 29, 2025
…ly for a custo… (#2419) …m op with multiple mutated inputs and None return type. (pytorch#157133) This is an attempt to fix a memory allocation issue when using `torch.compile` with a custom layernorm kernel in vllm: ```C++ // In-place fused Add and RMS Normalization. ops.def( "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); ``` We observed abnormal extra memory allocations with this op enabled using `torch.compile`: <img width="738" alt="{374E9FCF-FB46-4750-8B60-D31E3ADCE00A}" src="https://github.com/user-attachments/assets/6c45e1aa-ccde-4c56-99dc-bf4776d699d5" /> and without this op: <img width="738" alt="{9BB08EFE-FFE3-4D06-82C0-C70BBE6ADD56}" src="https://github.com/user-attachments/assets/56e2ee43-ab87-492d-834c-69e9cafbb0df" /> After investigation, we found that this is because the compiler considers the two buffers for the two mutated inputs `Tensor input` and `Tensor residual` should share a same dependency list, which makes it can not reuse the buffer of `Tensor input`. ``` buf1.users = [ NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False), ] buf16.users = [ NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False), ] ``` ``` op13: ExternKernelSchedulerNode(FallbackKernel) op13.writes = [ StarDep(name='buf17', mode=None), StarDep(name='buf18', mode=None), StarDep(name='buf19', mode=None)] op13.unmet_dependencies = [ StarDep(name='buf13', mode=None), StarDep(name='buf16', mode=None), WeakDep(name='buf11', mutating_buf='buf18'), WeakDep(name='buf12', mutating_buf='buf18'), WeakDep(name='buf13', mutating_buf='buf18'), WeakDep(name='buf2', mutating_buf='buf18'), WeakDep(name='buf3', mutating_buf='buf18')] op13.met_dependencies = [StarDep(name='arg11_1', mode=None)] op13.outputs = [ buf17: FallbackKernel buf17.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0]) buf17.aliases = ['buf16', 'buf1'] buf17.users = [ NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False), ] buf18: MutationOutput buf18.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0]) buf18.mutations = ['buf16'] buf18.users = [ NodeUser(node=ExternKernelSchedulerNode(name='op14'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=True), ] buf19: MutationOutput buf19.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0]) buf19.mutations = ['buf1'] buf19.users = [NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False)] ] op13.node.kernel = torch.ops._C.fused_add_rms_norm.default ``` Here we can see `buf16` shares the same dependency list with `buf1` because `buf16` and `buf1` are in the aliases list of `buf17`. This is incorrect since those two are two separate tensors. And this makes the compiler could not reuse `buf16` for subsequent ops. Pull Request resolved: pytorch#157133 Approved by: https://github.com/jansel (cherry picked from commit 02724b5) Fixes #ISSUE_NUMBER Co-authored-by: charlifu <[email protected]>
tvukovic-amd
pushed a commit
that referenced
this pull request
Aug 20, 2025
…ly for a custo… (#2419) …m op with multiple mutated inputs and None return type. (pytorch#157133) This is an attempt to fix a memory allocation issue when using `torch.compile` with a custom layernorm kernel in vllm: ```C++ // In-place fused Add and RMS Normalization. ops.def( "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); ``` We observed abnormal extra memory allocations with this op enabled using `torch.compile`: <img width="738" alt="{374E9FCF-FB46-4750-8B60-D31E3ADCE00A}" src="https://github.com/user-attachments/assets/6c45e1aa-ccde-4c56-99dc-bf4776d699d5" /> and without this op: <img width="738" alt="{9BB08EFE-FFE3-4D06-82C0-C70BBE6ADD56}" src="https://github.com/user-attachments/assets/56e2ee43-ab87-492d-834c-69e9cafbb0df" /> After investigation, we found that this is because the compiler considers the two buffers for the two mutated inputs `Tensor input` and `Tensor residual` should share a same dependency list, which makes it can not reuse the buffer of `Tensor input`. ``` buf1.users = [ NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False), ] buf16.users = [ NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False), ] ``` ``` op13: ExternKernelSchedulerNode(FallbackKernel) op13.writes = [ StarDep(name='buf17', mode=None), StarDep(name='buf18', mode=None), StarDep(name='buf19', mode=None)] op13.unmet_dependencies = [ StarDep(name='buf13', mode=None), StarDep(name='buf16', mode=None), WeakDep(name='buf11', mutating_buf='buf18'), WeakDep(name='buf12', mutating_buf='buf18'), WeakDep(name='buf13', mutating_buf='buf18'), WeakDep(name='buf2', mutating_buf='buf18'), WeakDep(name='buf3', mutating_buf='buf18')] op13.met_dependencies = [StarDep(name='arg11_1', mode=None)] op13.outputs = [ buf17: FallbackKernel buf17.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0]) buf17.aliases = ['buf16', 'buf1'] buf17.users = [ NodeUser(node=ExternKernelSchedulerNode(name='op2'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op9'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op13'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=False), ] buf18: MutationOutput buf18.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0]) buf18.mutations = ['buf16'] buf18.users = [ NodeUser(node=ExternKernelSchedulerNode(name='op14'), can_inplace=False, is_weak=False), NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op24'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op31'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op35'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op42'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op46'), can_inplace=False, is_weak=True), NodeUser(node=ExternKernelSchedulerNode(name='op53'), can_inplace=False, is_weak=True), ] buf19: MutationOutput buf19.layout = NoneLayout(device=device(type='cuda', index=0), size=[0], stride=[0]) buf19.mutations = ['buf1'] buf19.users = [NodeUser(node=ExternKernelSchedulerNode(name='op20'), can_inplace=False, is_weak=False)] ] op13.node.kernel = torch.ops._C.fused_add_rms_norm.default ``` Here we can see `buf16` shares the same dependency list with `buf1` because `buf16` and `buf1` are in the aliases list of `buf17`. This is incorrect since those two are two separate tensors. And this makes the compiler could not reuse `buf16` for subsequent ops. Pull Request resolved: pytorch#157133 Approved by: https://github.com/jansel (cherry picked from commit 02724b5) Fixes #ISSUE_NUMBER Co-authored-by: charlifu <[email protected]>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
…m op with multiple mutated inputs and None return type. (pytorch#157133)
This is an attempt to fix a memory allocation issue when using
torch.compile
with a custom layernorm kernel in vllm:We observed abnormal extra memory allocations with this op enabled using
and without this op:

torch.compile
:After investigation, we found that this is because the compiler considers the two buffers for the two mutated inputs
Tensor input
andTensor residual
should share a same dependency list, which makes it can not reuse the buffer ofTensor input
.Here we can see
buf16
shares the same dependency list withbuf1
becausebuf16
andbuf1
are in the aliases list ofbuf17
. This is incorrect since those two are two separate tensors. And this makes the compiler could not reusebuf16
for subsequent ops.Pull Request resolved: pytorch#157133
Approved by: https://github.com/jansel
(cherry picked from commit 02724b5)
Fixes #ISSUE_NUMBER